import torch.nn as nn

from lib.Euclidean.blocks.resnet_blocks import BasicBlock, Bottleneck

from lib.lorentz.blocks.resnet_blocks import (
    LorentzBasicBlock,
    LorentzBottleneck,
    LorentzInputBlock,
)

from lib.lorentz.layers import LorentzMLR, LorentzGlobalAvgPool2d, LorentzTransform
from lib.lorentz.manifold import CustomLorentz

from lib.poincare.blocks.resnet_blocks import (
    PoincareBasicBlock,
    PoincareInputBlock,
)

from lib.poincare.layers import UnidirectionalPoincareMLR, PoincareGlobalAvgPool2d
from lib.poincare.manifold import CustomPoincare

# Names exported by a star-import of this module. Only the Euclidean
# constructors are listed; the Lorentz_*, Poincare_* and Mixed_* factories
# defined in this file are not part of the star-import surface and must be
# imported by name.
__all__ = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]


class ManifoldSwapper(nn.Module):
    """Convert feature maps between a channels-first Euclidean layout and the
    channels-last layout handed to ``manifold``.

    Args:
        manifold: object providing ``projx`` and ``logmap0``.
        to_euclidean: direction of the conversion.
            * ``False`` (default): permute ``(N, C, H, W)`` to ``(N, H, W, C)``,
              prepend one zero entry along the last axis and return
              ``manifold.projx`` of the result.
            * ``True``: apply ``manifold.logmap0``, drop the first entry of the
              last axis and permute ``(N, H, W, C)`` back to ``(N, C, H, W)``.
    """

    def __init__(self, manifold, to_euclidean=False):
        super().__init__()
        self.manifold = manifold
        self.to_euclidean = to_euclidean

    def forward(self, x):
        """Apply the conversion selected by ``to_euclidean`` to a 4-D tensor."""
        if not self.to_euclidean:
            channels_last = x.permute(0, 2, 3, 1)
            # pad=(1, 0) adds a single leading slot on the last axis only
            # (presumably the extra manifold coordinate that projx then sets).
            lifted = nn.functional.pad(channels_last, pad=(1, 0))
            return self.manifold.projx(lifted)

        tangent = self.manifold.logmap0(x)
        # Discard the leading coordinate and go back to channels-first.
        return tangent[..., 1:].permute(0, 3, 1, 2)


class ResNet(nn.Module):
    """ResNet backbone built from Euclidean, Lorentz or Poincare residual blocks.

    The network operates in one of three modes, selected by ``block`` and
    ``manifold``:

    * Euclidean: ``manifold`` is ``None`` and ``block`` is ``BasicBlock`` or
      ``Bottleneck``.
    * Hyperbolic: ``manifold`` is a ``CustomLorentz`` or ``CustomPoincare``
      instance and ``block`` is a block class that takes the manifold as its
      first constructor argument.
    * Mixed: ``block`` is a two-element list/tuple ``[block_a, block_b]``.
      Stages ``conv2_x``/``conv4_x`` use ``block_a`` and ``conv3_x``/``conv5_x``
      use ``block_b``. ``ManifoldSwapper`` layers move the features onto the
      manifold before each ``block_a`` stage and back to Euclidean space
      before each ``block_b`` stage. The stem, the global pooling and the
      classifier are Euclidean in this mode.

    Args:
        block: block class, or ``[block_a, block_b]`` for mixed mode.
        num_blocks: sequence of four ints, number of blocks per stage.
        manifold: ``None``, ``CustomLorentz`` or ``CustomPoincare`` instance.
        img_dim: sequence whose first entry is the number of input channels
            (only index 0 is read).
        embed_dim: ``out_channels`` of the last stage, before block expansion.
        num_classes: number of outputs of the classifier.
        bias: forwarded to the stem, the blocks and the Euclidean classifier.
        input_kernel: kernel size of the Euclidean stem convolution (its
            padding is fixed to 1). Not forwarded to the manifold input blocks.
        remove_linear: if ``True`` no classifier is created and ``forward``
            returns the flattened pooled features.

    Raises:
        RuntimeError: if ``manifold`` is of an unsupported type for the
            requested block, stem, pooling or classifier.
    """

    def __init__(
        self,
        block,
        num_blocks,
        manifold=None,
        img_dim=(3, 32, 32),
        embed_dim=512,
        num_classes=100,
        bias=True,
        input_kernel=3,
        remove_linear=False,
    ):
        super(ResNet, self).__init__()

        self.img_dim = img_dim[0]
        self.in_channels = 64
        self.conv3_dim = 128
        self.conv4_dim = 256
        self.embed_dim = embed_dim

        self.input_kernel = input_kernel

        self.bias = bias
        self.block = block

        if self._is_mixed():
            block_a = block[0]
            block_b = block[1]

            # Alternate representations: manifold for block_a stages,
            # Euclidean for block_b stages.
            self.swap_1 = ManifoldSwapper(manifold, to_euclidean=False)
            self.swap_2 = ManifoldSwapper(manifold, to_euclidean=True)
            self.swap_3 = ManifoldSwapper(manifold, to_euclidean=False)
            self.swap_4 = ManifoldSwapper(manifold, to_euclidean=True)

        else:
            block_a = block_b = block
            # An empty Sequential acts as the identity; no conversion needed.
            self.swap_1 = self.swap_2 = self.swap_3 = self.swap_4 = nn.Sequential()

        self.manifold = manifold

        # NOTE: _make_layer updates self.in_channels, so the stages must be
        # created in forward order.
        self.conv1 = self._get_inConv()
        self.conv2_x = self._make_layer(block_a, out_channels=self.in_channels, num_blocks=num_blocks[0], stride=1)
        self.conv3_x = self._make_layer(block_b, out_channels=self.conv3_dim, num_blocks=num_blocks[1], stride=2)
        self.conv4_x = self._make_layer(block_a, out_channels=self.conv4_dim, num_blocks=num_blocks[2], stride=2)
        self.conv5_x = self._make_layer(block_b, out_channels=self.embed_dim, num_blocks=num_blocks[3], stride=2)
        self.avg_pool = self._get_GlobalAveragePooling()

        if remove_linear:
            self.predictor = None
        else:
            # The last stage is built from block_b, so its expansion defines
            # the feature width. (Reading `.expansion` from `block` itself
            # fails when `block` is a list in mixed mode.)
            self.predictor = self._get_predictor(self.embed_dim * block_b.expansion, num_classes)

    def _is_mixed(self):
        """Return True when ``self.block`` is a ``[block_a, block_b]`` pair."""
        return isinstance(self.block, (list, tuple))

    def forward(self, x, return_distances=False):
        """Run the network on ``x``.

        Args:
            x: input batch passed to the stem.
            return_distances: if ``True`` a tuple ``(out, None)`` is returned.
                Distances are currently not computed; ``None`` is a
                placeholder that keeps the tuple shape for callers.

        Returns:
            The classifier output, or the flattened pooled features when no
            classifier was created; wrapped in a tuple if ``return_distances``.
        """
        out = self.swap_1(self.conv1(x))

        out = self.swap_2(self.conv2_x(out))
        out = self.swap_3(self.conv3_x(out))
        out = self.swap_4(self.conv4_x(out))
        out = self.conv5_x(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)

        if self.predictor is not None:
            out = self.predictor(out)

        if return_distances:
            # A per-stage average of manifold.dist0 was sketched here earlier
            # but is disabled, hence the None placeholder.
            return out, None

        return out

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        The first block uses ``stride``; all following blocks use stride 1.
        ``self.in_channels`` is updated to ``out_channels * block.expansion``
        after every block so that the next block/stage is wired correctly.

        Raises:
            RuntimeError: if ``block`` is not a Euclidean block and the
                manifold is neither ``CustomLorentz`` nor ``CustomPoincare``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for block_stride in strides:
            if block in (BasicBlock, Bottleneck):
                # Euclidean blocks do not take a manifold argument.
                layers.append(block(self.in_channels, out_channels, block_stride, self.bias))
            elif (type(self.manifold) is CustomLorentz) or (type(self.manifold) is CustomPoincare):
                layers.append(
                    block(
                        self.manifold,
                        self.in_channels,
                        out_channels,
                        block_stride,
                        self.bias
                    )
                )
            else:
                raise RuntimeError(
                    f"Manifold {type(self.manifold)} not supported in ResNet."
                )

            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def _get_inConv(self):
        """Create the stem: Euclidean conv/BN/ReLU in Euclidean and mixed mode,
        otherwise the input block of the configured manifold."""
        if self.manifold is None or self._is_mixed():
            return nn.Sequential(
                nn.Conv2d(
                    self.img_dim,
                    self.in_channels,
                    kernel_size=self.input_kernel,
                    padding=1,
                    bias=self.bias
                ),
                nn.BatchNorm2d(self.in_channels),
                nn.ReLU(inplace=True),
            )

        elif type(self.manifold) is CustomLorentz:
            return LorentzInputBlock(
                self.manifold,
                self.img_dim,
                self.in_channels,
                self.bias
            )

        elif type(self.manifold) is CustomPoincare:
            return PoincareInputBlock(
                self.manifold,
                self.img_dim,
                self.in_channels,
                self.bias
            )

        else:
            raise RuntimeError(
                f"Manifold {type(self.manifold)} not supported in ResNet."
            )

    def _get_predictor(self, in_features, num_classes):
        """Create the classifier for ``in_features`` pooled features.

        In mixed mode the pooling layer is Euclidean, so the classifier is a
        plain linear layer as well, consistent with the stem and the pooling.
        """
        if self.manifold is None or self._is_mixed():
            return nn.Linear(in_features, num_classes, bias=self.bias)

        elif type(self.manifold) is CustomLorentz:
            # +1 accounts for the additional leading coordinate of Lorentz points.
            return LorentzMLR(self.manifold, in_features+1, num_classes)

        elif type(self.manifold) is CustomPoincare:
            return UnidirectionalPoincareMLR(self.manifold, in_features, num_classes)

        else:
            raise RuntimeError(f"Manifold {type(self.manifold)} not supported in ResNet.")

    def _get_GlobalAveragePooling(self):
        """Create the global pooling layer: Euclidean adaptive average pooling
        in Euclidean and mixed mode, otherwise the manifold's global pooling."""
        if self.manifold is None or self._is_mixed():
            return nn.AdaptiveAvgPool2d((1, 1))

        elif type(self.manifold) is CustomLorentz:
            return LorentzGlobalAvgPool2d(self.manifold, keep_dim=True)

        elif type(self.manifold) is CustomPoincare:
            return PoincareGlobalAvgPool2d(self.manifold, keep_dim=True)

        else:
            raise RuntimeError(f"Manifold {type(self.manifold)} not supported in ResNet.")

#################################################
#       Lorentz
#################################################
def Lorentz_resnet10(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-10 (one LorentzBasicBlock per stage).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBasicBlock, [1, 1, 1, 1], space, **kwargs)


def Lorentz_resnet18(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-18 (LorentzBasicBlock, layout [2, 2, 2, 2]).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBasicBlock, [2, 2, 2, 2], space, **kwargs)


def Lorentz_resnet34(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-34 (LorentzBasicBlock, layout [3, 4, 6, 3]).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBasicBlock, [3, 4, 6, 3], space, **kwargs)


def Lorentz_resnet50(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-50 (LorentzBottleneck, layout [3, 4, 6, 3]).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBottleneck, [3, 4, 6, 3], space, **kwargs)


def Lorentz_resnet101(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-101 (LorentzBottleneck, layout [3, 4, 23, 3]).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBottleneck, [3, 4, 23, 3], space, **kwargs)


def Lorentz_resnet152(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Lorentz ResNet-152 (LorentzBottleneck, layout [3, 8, 36, 3]).

    A falsy ``manifold`` is replaced by ``CustomLorentz(k=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    return ResNet(LorentzBottleneck, [3, 8, 36, 3], space, **kwargs)


#################################################
#       Poincare
#################################################
def Poincare_resnet10(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Poincare ResNet-10 (one PoincareBasicBlock per stage).

    A falsy ``manifold`` is replaced by ``CustomPoincare(c=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    ball = manifold or CustomPoincare(c=k, learnable=learn_k)
    return ResNet(PoincareBasicBlock, [1, 1, 1, 1], ball, **kwargs)


def Poincare_resnet18(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Poincare ResNet-18 (PoincareBasicBlock, layout [2, 2, 2, 2]).

    A falsy ``manifold`` is replaced by ``CustomPoincare(c=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    ball = manifold or CustomPoincare(c=k, learnable=learn_k)
    return ResNet(PoincareBasicBlock, [2, 2, 2, 2], ball, **kwargs)


def Poincare_resnet34(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a Poincare ResNet-34 (PoincareBasicBlock, layout [3, 4, 6, 3]).

    A falsy ``manifold`` is replaced by ``CustomPoincare(c=k, learnable=learn_k)``;
    remaining keyword arguments are forwarded to ``ResNet``.
    """
    ball = manifold or CustomPoincare(c=k, learnable=learn_k)
    return ResNet(PoincareBasicBlock, [3, 4, 6, 3], ball, **kwargs)


#################################################
#       Euclidean
#################################################
def resnet10(**kwargs):
    """Build a Euclidean ResNet-10 (one BasicBlock per stage).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18(**kwargs):
    """Build a Euclidean ResNet-18 (BasicBlock, layout [2, 2, 2, 2]).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """Build a Euclidean ResNet-34 (BasicBlock, layout [3, 4, 6, 3]).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """Build a Euclidean ResNet-50 (Bottleneck, layout [3, 4, 6, 3]).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """Build a Euclidean ResNet-101 (Bottleneck, layout [3, 4, 23, 3]).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    """Build a Euclidean ResNet-152 (Bottleneck, layout [3, 8, 36, 3]).

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


def Mixed_resnet18(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a mixed Lorentz/Euclidean ResNet-18 (layout [2, 2, 2, 2]).

    The block pair ``[LorentzBasicBlock, BasicBlock]`` is handed to ``ResNet``,
    which uses the first entry for stages conv2_x/conv4_x and the second for
    conv3_x/conv5_x. A falsy ``manifold`` is replaced by
    ``CustomLorentz(k=k, learnable=learn_k)``; remaining keyword arguments are
    forwarded to ``ResNet``.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    stage_blocks = [LorentzBasicBlock, BasicBlock]
    return ResNet(stage_blocks, [2, 2, 2, 2], space, **kwargs)


def Mixed_resnet50(k=1, learn_k=False, manifold=None, **kwargs):
    """Build a mixed Euclidean/Lorentz ResNet with layout [2, 2, 2, 2].

    The block pair ``[BasicBlock, LorentzBasicBlock]`` is handed to ``ResNet``,
    which uses the first entry for stages conv2_x/conv4_x and the second for
    conv3_x/conv5_x. A falsy ``manifold`` is replaced by
    ``CustomLorentz(k=k, learnable=learn_k)``; remaining keyword arguments are
    forwarded to ``ResNet``.

    NOTE(review): despite the name this builds an 18-layer basic-block layout,
    not a bottleneck ResNet-50, and the block order is the reverse of
    ``Mixed_resnet18`` -- confirm that both are intended.
    """
    space = manifold or CustomLorentz(k=k, learnable=learn_k)
    stage_blocks = [BasicBlock, LorentzBasicBlock]
    return ResNet(stage_blocks, [2, 2, 2, 2], space, **kwargs)
